-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM #140573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/momchil-velikov/vector-contract-i8mm-transform-dialect
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesPatch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff 5 Files Affected:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17],
+ [-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c256 = arith.constant 256 : i32
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
+ [-20, -36, -3, 39, -48, -31, -25, -21],
+ [-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49],
+ [-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
+ %acc_m = memref.alloca() : memref<8x8xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+ vector.print %acc4 : vector<[4]xi32>
+ vector.print %acc5 : vector<[4]xi32>
+ vector.print %acc6 : vector<[4]xi32>
+ vector.print %acc7 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38],
+ [ -3, 9, 43, -30, -32, 39, 41, -39],
+ [-13, -21, -25, 27, 47, -36, -11, -11],
+ [ -4, -20, 36, 11, 13, -23, 24, -13],
+ [-20, 30, -5, 1, 42, -37, -22, 35],
+ [-22, 38, -4, 44, 25, -31, 23, -39],
+ [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
+
+ %lhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+ %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+ %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+ %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+ %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+ vector.print %lhs4 : vector<8xi8>
+ vector.print %lhs5 : vector<8xi8>
+ vector.print %lhs6 : vector<8xi8>
+ vector.print %lhs7 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
+ [ 46, -45, -48, -46, -24, 31, -36, 22],
+ [ 2, 36, 45, -29, -37, -49, -20, -35],
+ [ -6, 23, 23, 15, 20, 4, -8, -2],
+ [-35, -6, 16, 49, -50, 9, -44, 13],
+ [ 24, 1, -4, -44, 41, 15, -43, 44],
+ [ 44, 0, -10, 41, 22, 44, -40, 0],
+ [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
+
+ %rhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+ // Display the result of the multilication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+ vector.print %u4 : vector<[4]xi32>
+ vector.print %u5 : vector<[4]xi32>
+ vector.print %u6 : vector<[4]xi32>
+ vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
+ [221, 25, 164, 97, 156, 221, 218, 177],
+ [171, 160, 219, 191, 144, 45, 161, 210],
+ [223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesPatch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff 5 Files Affected:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17],
+ [-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c256 = arith.constant 256 : i32
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
+ [-20, -36, -3, 39, -48, -31, -25, -21],
+ [-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49],
+ [-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
+ %acc_m = memref.alloca() : memref<8x8xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+ vector.print %acc4 : vector<[4]xi32>
+ vector.print %acc5 : vector<[4]xi32>
+ vector.print %acc6 : vector<[4]xi32>
+ vector.print %acc7 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38],
+ [ -3, 9, 43, -30, -32, 39, 41, -39],
+ [-13, -21, -25, 27, 47, -36, -11, -11],
+ [ -4, -20, 36, 11, 13, -23, 24, -13],
+ [-20, 30, -5, 1, 42, -37, -22, 35],
+ [-22, 38, -4, 44, 25, -31, 23, -39],
+ [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
+
+ %lhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+ %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+ %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+ %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+ %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+ vector.print %lhs4 : vector<8xi8>
+ vector.print %lhs5 : vector<8xi8>
+ vector.print %lhs6 : vector<8xi8>
+ vector.print %lhs7 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
+ [ 46, -45, -48, -46, -24, 31, -36, 22],
+ [ 2, 36, 45, -29, -37, -49, -20, -35],
+ [ -6, 23, 23, 15, 20, 4, -8, -2],
+ [-35, -6, 16, 49, -50, 9, -44, 13],
+ [ 24, 1, -4, -44, 41, 15, -43, 44],
+ [ 44, 0, -10, 41, 22, 44, -40, 0],
+ [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
+
+ %rhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+ // Display the result of the multilication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+ vector.print %u4 : vector<[4]xi32>
+ vector.print %u5 : vector<[4]xi32>
+ vector.print %u6 : vector<[4]xi32>
+ vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
+ [221, 25, 164, 97, 156, 221, 218, 177],
+ [171, 160, 219, 191, 144, 45, 161, 210],
+ [223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
|
194c1c7
to
87f2964
Compare
251f93e
to
4ab7765
Compare
Thanks - great to finally be reaching this stage! I have a few high-level questions and suggestions: 1. Why is the scalable dimension always [4]? From the current tests, it looks like the scalable dim is always 2. Reduce duplication in the 4x8x4 tests The current tests differ only in terms of input/output and For example: func.func @main_smmla() {
// Init LHS, RHS, ACC
// CHECK-LINES for LHS
print(lhs);
// CHECK-LINES for RHS
print(rhs);
arith.extsi (lhs)
arith.extsi (rhs)
vector.contract
// CHECK-LINES for ACC
print(acc);
} This would keep the test logic focused and easier to maintain. 3. Add checks for generated IR (LLVM dialect) It would be good to verify that the lowered IR includes the correct SME MMLA intrinsics. For example: // CHECK-COUNT-4: llvm.intr.smmla This would help confirm both correctness and that the expected number of operations are emitted. 4. Consider toggling VL within tests From what I can tell, this would only work if the inputs are generated inside a loop, similar to this example: llvm-project/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir Lines 19 to 37 in 88f61f2
That might be a nice validation of the "scalability" aspect. |
No description provided.